/*
 * Lincheck
 *
 * Copyright (C) 2019 - 2025 JetBrains s.r.o.
 *
 * This Source Code Form is subject to the terms of the
 * Mozilla Public License, v. 2.0. If a copy of the MPL was not distributed
 * with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
 */

package org.jetbrains.kotlinx.lincheck.trace

import org.jetbrains.lincheck.trace.replaceNestedClassDollar
import org.jetbrains.lincheck.util.AnalysisProfile
import org.jetbrains.lincheck.trace.*


/**
 * Runs all trace-simplification passes defined in this file over the table and returns the result.
 *
 * The passes are applied one after another, in exactly the order listed below. Each of them rewrites
 * every tree of the table through `compressNodes`.
 *
 * NOTE(review): some passes may rely on the output of earlier ones (e.g. the name clean-ups run after
 * the structural call-merging passes) — confirm before reordering.
 */
internal fun SingleThreadedTable<TraceNode>.compressTrace() = this
    // Structural passes: remove or merge compiler-generated / wrapper call nodes.
    .compressSyntheticFieldAccess()
    .compressSuspendImpl()
    .compressDefaultPairs()
    .compressAccessPairs()
    .compressUserThreadRun()
    .compressThreadStart()
    // Name clean-up passes: rewrite method, owner and field names for readability.
    .removeCoroutinesCoreSuffix()
    .removeAssertionsKtOwnerName()
    .compressInlineIV()
    .compressDollarThis()
    // Resolves to the table-level pass declared in this file, not to the String helper of the same name.
    .replaceNestedClassDollar()
    .compressLambdaCaptureSyntheticField()
    .compressVolatileDollar()

/**
 * Produces a short, human-readable rendering of this stack trace element.
 *
 * The element's `toString()` output is first cut after its last `/` (see [removePackages]).
 * The `$` separators of nested classes in the leading class-name part are then rewritten
 * (see [removeStackTraceNestedClassDollarSigns]).
 */
internal fun StackTraceElement.compress(): String {
    val rendered = toString()
    val withoutPrefix = rendered.removePackages()
    return withoutPrefix.removeStackTraceNestedClassDollarSigns()
}

/**
 * Compresses `receive$suspendImpl` calls.
 *
 * Such calls are part of suspend-function internals rather than user code. A call node `foo` qualifies
 * when it has exactly one child and that child is a call named `foo$suspendImpl`. In that case, the
 * `$suspendImpl` node is dropped and its children are attached directly to a copy of `foo`.
 *
 * Every re-attached child whose trace point is a [CodeLocationTracePoint] gets its `codeLocation`
 * overwritten with the code location of the removed `$suspendImpl` call.
 *
 * NOTE(review): node depth is not adjusted here; confirm that the re-attached nodes do not need it.
 */
private fun SingleThreadedTable<TraceNode>.compressSuspendImpl() = compressNodes { node ->
    if (node !is CallNode || node.children.size != 1) return@compressNodes node
    val impl = node.children[0] as? CallNode ?: return@compressNodes node
    val expectedImplName = "${node.tracePoint.methodName}\$suspendImpl"
    if (impl.tracePoint.methodName != expectedImplName) return@compressNodes node

    // Location of the `$suspendImpl` call; handed down to the nodes that are lifted one level up.
    val implCallSite = impl.tracePoint.codeLocation
    node.copy().also { merged ->
        for (grandChild in impl.children) {
            (grandChild.tracePoint as? CodeLocationTracePoint)?.codeLocation = implCallSite
            merged.addChild(grandChild)
        }
    }
}

/**
 * Compresses `fun$default(...)` calls.
 *
 * A Kotlin function with default parameter values appears in the trace as two nested calls: the
 * synthetic `fun$default` call, whose only child is the call to `fun` itself.
 *
 * For example:
 *
 * ```
 * A.callMe$default(A#1, 3, null, 2, null) at A.operation(A.kt:23)
 *   A.callMe(3, "Hey") at A.callMe$default(A.kt:27)
 * ```
 *
 * will be collapsed into:
 *
 * ```
 * A.callMe(3, "Hey") at A.operation(A.kt:23)
 * ```
 *
 * A pair is recognized when both calls belong to the same class and [isDefaultPair] accepts their
 * method names. The merge itself is performed by [combineNodes].
 */
private fun SingleThreadedTable<TraceNode>.compressDefaultPairs() = compressNodes { node ->
    val outer = node as? CallNode ?: return@compressNodes node
    val inner = outer.children.singleOrNull() as? CallNode ?: return@compressNodes node
    val sameClass = outer.tracePoint.className == inner.tracePoint.className
    if (sameClass && isDefaultPair(outer.tracePoint.methodName, inner.tracePoint.methodName)) {
        combineNodes(outer, inner)
    } else {
        node
    }
}

/**
 * Compresses `.access$` calls.
 *
 * The Kotlin compiler generates `.access$` methods so that lambdas, inner classes, etc. can reach
 * members that would otherwise be inaccessible (e.g., private ones).
 *
 * For example:
 *
 * ```
 * A.access$callMe() at A.operation(A.kt:N)
 *  A.callMe() at A.access$callMe(A.kt:N)
 * ```
 *
 * will be collapsed into:
 *
 * ```
 * A.callMe() at A.operation(A.kt:N)
 * ```
 *
 * A pair is recognized when both calls belong to the same class and [isAccessPair] accepts their
 * method names. The merge itself is performed by [combineNodes].
 */
private fun SingleThreadedTable<TraceNode>.compressAccessPairs() = compressNodes { node ->
    val accessor = node as? CallNode ?: return@compressNodes node
    val target = accessor.children.singleOrNull() as? CallNode ?: return@compressNodes node
    val isPair = accessor.tracePoint.className == target.tracePoint.className &&
        isAccessPair(accessor.tracePoint.methodName, target.tracePoint.methodName)
    if (isPair) combineNodes(accessor, target) else node
}

/**
 * Merges a compiler-generated wrapper call ([parent]) with the real call it delegates to ([child]).
 * For more details check [isDefaultPair] and [isAccessPair].
 *
 * The merge happens only if both trace points report the same returned value and the same thrown
 * exception. In that case, the parent's trace point takes over the child's method name and parameters
 * (the parent keeps everything else, including its code location). A copy of the parent that adopts
 * all of the child's children is then returned. Otherwise [parent] itself is returned unchanged.
 *
 * Note: the parent's trace point is mutated in place before the copy is made.
 */
private fun combineNodes(parent: CallNode, child: CallNode): TraceNode {
    val parentPoint = parent.tracePoint
    val childPoint = child.tracePoint

    // TODO investigate why in rare cases return values are not equal #682
    val sameOutcome = parentPoint.returnedValue == childPoint.returnedValue &&
        parentPoint.thrownException == childPoint.thrownException
    if (!sameOutcome) return parent

    parentPoint.methodName = childPoint.methodName
    parentPoint.parameters = childPoint.parameters

    return parent.copy().also { merged ->
        for (grandChild in child.children) merged.addChild(grandChild)
    }
}

/**
 * Compresses synthetic field access methods (`access$get` and `access$set`).
 *
 * The Kotlin compiler generates these methods when a lambda or inner class accesses a private field.
 * A call node qualifies when it has exactly one child, that child is an [EventNode], and
 * [isSyntheticFieldAccess] accepts the call's method name. The call is then replaced by that single
 * event. A read or write event also inherits the code location of the removed call. Any other event
 * kind keeps its own location.
 *
 * For example:
 *
 * ```
 * A.access$get_field(A#1) at A.operation(A.kt:N)
 *   field ➜ value at A.access$get_field(A.kt:N)
 * ```
 *
 * will be collapsed into:
 *
 * ```
 * field ➜ value at A.operation(A.kt:N)
 * ```
 *
 * This is different from `access$fun`, which is addressed in [compressAccessPairs].
 *
 * @see AccessFieldRepresentationTest
 */
private fun SingleThreadedTable<TraceNode>.compressSyntheticFieldAccess() = compressNodes { node ->
    if (node !is CallNode) return@compressNodes node
    val fieldEvent = node.children.singleOrNull() as? EventNode ?: return@compressNodes node
    if (!isSyntheticFieldAccess(node.tracePoint.methodName)) return@compressNodes node

    val callSite = node.tracePoint.codeLocation
    when (val point = fieldEvent.tracePoint) {
        is ReadTracePoint -> point.codeLocation = callSite
        is WriteTracePoint -> point.codeLocation = callSite
        else -> {}
    }
    fieldEvent
}

/**
 * Removes the `$kotlinx_coroutines_core` suffix from method names.
 *
 * The suffix is an implementation detail of compiled kotlinx.coroutines code. NOTE(review): it is most
 * likely the JVM name mangling of `internal` members — confirm.
 *
 * Two names are cleaned:
 * - the called method name of every [CallNode];
 * - the method name of the recorded stack trace element of every [CodeLocationTracePoint]. That element
 *   is replaced by a new one built from the same class name, file name and line number.
 *
 * For example:
 *
 * ```
 * A.someMethod$kotlinx_coroutines_core() at A.kt:10
 * ```
 *
 * will be transformed to:
 *
 * ```
 * A.someMethod() at A.kt:10
 * ```
 */
private fun SingleThreadedTable<TraceNode>.removeCoroutinesCoreSuffix() = compressNodes { node ->
    // 1. The name of the method being called.
    if (node is CallNode) {
        val calledName = node.tracePoint.methodName
        if (calledName.hasCoroutinesCoreSuffix()) {
            node.tracePoint.methodName = calledName.removeCoroutinesCoreSuffix()
        }
    }

    // 2. The method name inside the recorded code location.
    val located = node.tracePoint as? CodeLocationTracePoint
    if (located != null) {
        val element = located.stackTraceElement
        if (element.methodName.hasCoroutinesCoreSuffix()) {
            located.stackTraceElement = StackTraceElement(
                element.className,
                element.methodName.removeCoroutinesCoreSuffix(),
                element.fileName,
                element.lineNumber,
            )
        }
    }

    node
}

/**
 * Clears (sets to `null`) the owner name of calls into assertion helpers.
 *
 * A [CallNode] is affected when its owner name is `kotlin.test.AssertionsKt` or its class name is
 * `org.junit.jupiter.api.Assertions`. All other nodes are returned untouched.
 *
 * NOTE(review): the first check compares `ownerName` while the second compares `className`; confirm
 * that this asymmetry is intentional.
 */
private fun SingleThreadedTable<TraceNode>.removeAssertionsKtOwnerName() = compressNodes { node ->
    if (node is CallNode) {
        val point = node.tracePoint
        val isKotlinTestAssertion = point.ownerName == "kotlin.test.AssertionsKt"
        val isJUnitAssertion = point.className == "org.junit.jupiter.api.Assertions"
        if (isKotlinTestAssertion || isJUnitAssertion) point.updateOwnerName(null)
    }
    node
}

/**
 * Removes the lambda invocation line at the beginning of a user-defined thread trace.
 *
 * The pass applies when a thread-start call node has exactly one child and that child is a call
 * accepted by [isUserThreadStart] (a `kotlin.jvm.functions.Function0.invoke` call). The `invoke` node
 * is dropped, and its children are attached directly to a copy of the thread-start node.
 */
private fun SingleThreadedTable<TraceNode>.compressUserThreadRun() = compressNodes { node ->
    if (node !is CallNode || !node.tracePoint.isThreadStart) return@compressNodes node
    val lambdaCall = node.children.singleOrNull() as? CallNode ?: return@compressNodes node
    if (!isUserThreadStart(node.tracePoint, lambdaCall.tracePoint)) return@compressNodes node

    val threadRoot = node.copy()
    for (bodyNode in lambdaCall.children) threadRoot.addChild(bodyNode)
    threadRoot
}

/**
 * When `thread() { ... }` is called, it is represented as
 * ```
 * thread creation line: Thread#2 at A.fun(location)
 *     Thread#2.start()
 * ```
 * and this pass gets rid of the second line.
 *
 * Concretely, the pass applies to a thread-creation call node that has exactly one child, when that
 * child in turn has exactly one child: a thread-start [EventNode]. The middle node is dropped, and the
 * thread-start event becomes the only child of a copy of the creation node. NOTE(review): this shape
 * presumably only occurs for `thread(start = true)`.
 */
private fun SingleThreadedTable<TraceNode>.compressThreadStart() = compressNodes { node ->
    if (node !is CallNode || !node.tracePoint.isThreadCreation()) return@compressNodes node
    val startCall = node.children.singleOrNull() ?: return@compressNodes node
    val startEvent = startCall.children.singleOrNull() as? EventNode ?: return@compressNodes node
    if (!startEvent.tracePoint.isThreadStart()) return@compressNodes node

    node.copy().also { creation -> creation.addChild(startEvent) }
}

/**
 * Removes the `$iv` suffix from owner names in inline functions.
 *
 * The Kotlin compiler adds a `$iv` suffix to variables in inline functions to avoid name conflicts.
 * The owner name of every [CallNode] is split on `.`, each segment is passed through `removeInlineIV`,
 * and the segments are joined back with `.`.
 *
 * For example:
 *
 * ```
 * cancellable$iv.getResult() at A.kt:10
 * ```
 *
 * will be transformed to:
 *
 * ```
 * cancellable.getResult() at A.kt:10
 * ```
 */
private fun SingleThreadedTable<TraceNode>.compressInlineIV() = compressNodes { node ->
    if (node is CallNode) {
        node.tracePoint.ownerName?.let { owner ->
            val cleaned = owner.split(".").joinToString(".") { segment -> segment.removeInlineIV() }
            node.tracePoint.updateOwnerName(cleaned)
        }
    }
    node
}

/**
 * Removes `$this` owner names from method calls.
 *
 * For every [CallNode] whose non-null owner name satisfies `isExactDollarThis`, the owner name is
 * cleared (set to `null`), so the call is shown without that synthetic receiver name.
 *
 * For example:
 *
 * ```
 * $this.someMethod() at A.kt:10
 * ```
 *
 * will be transformed to:
 *
 * ```
 * someMethod() at A.kt:10
 * ```
 */
private fun SingleThreadedTable<TraceNode>.compressDollarThis() = compressNodes { node ->
    if (node is CallNode) {
        val owner = node.tracePoint.ownerName
        if (owner != null && owner.isExactDollarThis()) node.tracePoint.updateOwnerName(null)
    }
    node
}

/**
 * Replaces dollar signs with dots in nested class names.
 *
 * In JVM bytecode, nested classes are written with dollar signs (e.g., `OuterClass$InnerClass`).
 * For every [CallNode] with a non-null owner name, the owner name is passed through the String-level
 * `replaceNestedClassDollar` helper to obtain the more familiar dot notation.
 *
 * For example:
 *
 * ```
 * OuterClass$InnerClass.method() at A.kt:10
 * ```
 *
 * will be transformed to:
 *
 * ```
 * OuterClass.InnerClass.method() at A.kt:10
 * ```
 */
private fun SingleThreadedTable<TraceNode>.replaceNestedClassDollar() = compressNodes { node ->
    if (node is CallNode) {
        val point = node.tracePoint
        val owner = point.ownerName
        if (owner != null) point.updateOwnerName(owner.replaceNestedClassDollar())
    }
    node
}

/**
 * Simplifies the representation of variables captured by lambdas.
 *
 * The pass handles read and write events on a field named `element` whose owner representation starts
 * with `$`. It renames the field to the owner representation with the leading `$` removed, and clears
 * the owner representation.
 *
 * For example:
 *
 * ```
 * $capturedVar.element ➜ value at A.kt:10
 * ```
 *
 * will be transformed to:
 *
 * ```
 * capturedVar ➜ value at A.kt:10
 * ```
 */
private fun SingleThreadedTable<TraceNode>.compressLambdaCaptureSyntheticField() = compressNodes { node ->
    if (node !is EventNode) return@compressNodes node
    when (val point = node.tracePoint) {
        is ReadTracePoint -> {
            val owner = point.ownerRepresentation
            if (owner != null && owner.startsWith("$") && point.fieldName == "element") {
                point.updateFieldName(owner.removePrefix("$"))
                point.updateOwnerRepresentation(null)
            }
        }
        is WriteTracePoint -> {
            val owner = point.ownerRepresentation
            if (owner != null && owner.startsWith("$") && point.fieldName == "element") {
                point.updateFieldName(owner.removePrefix("$"))
                point.updateOwnerRepresentation(null)
            }
        }
        else -> {}
    }
    node
}

/**
 * Removes the `$volatile` suffix from owner names.
 *
 * NOTE(review): fields with this suffix are presumably generated for volatile/atomic properties by a
 * compiler-side transformation — confirm. The owner name of every [CallNode] is passed through
 * `removeVolatileDollar`, and so is the owner name of every [EventNode] that carries a
 * [MethodCallTracePoint].
 *
 * For example:
 *
 * ```
 * someVar$volatile.someMethod() at A.kt:10
 * ```
 *
 * will be transformed to:
 *
 * ```
 * someVar.someMethod() at A.kt:10
 * ```
 *
 * @see AtomicReferencesNamesTests
 */
private fun SingleThreadedTable<TraceNode>.compressVolatileDollar() = compressNodes { node ->
    if (node is CallNode) {
        val callPoint = node.tracePoint
        callPoint.ownerName?.let { owner -> callPoint.updateOwnerName(owner.removeVolatileDollar()) }
    }

    // TODO this can be removed after IJTD-151 is merged.
    //  Could also be fixed in TraceNodes.kt but would cause unnecessary conflicts.
    if (node is EventNode) {
        val eventPoint = node.tracePoint
        if (eventPoint is MethodCallTracePoint) {
            eventPoint.ownerName?.let { owner -> eventPoint.updateOwnerName(owner.removeVolatileDollar()) }
        }
    }

    node
}

/**
 * Strips everything up to and including the last `/` from a stack trace element's string form.
 * A string without any `/` is returned unchanged.
 *
 * NOTE(review): this presumably removes a slash-separated package path (or a class-loader/module
 * prefix) in front of the class name — confirm how the elements' class names are built.
 */
private fun String.removePackages(): String = substringAfterLast('/')

/**
 * Rewrites nested-class dollar signs in the class-name part of a stack trace element's string form.
 *
 * The string is split at its first `.`:
 * - the part before it is passed through `replaceNestedClassDollar`;
 * - the part after it is kept verbatim.
 *
 * If there is no `.`, or nothing follows the first `.`, only the part before it is returned, without
 * any dollar-sign rewriting.
 */
private fun String.removeStackTraceNestedClassDollarSigns(): String {
    val pieces = split('.', limit = 2)
    val classPart = pieces[0]
    val remainder = pieces.getOrElse(1) { "" }
    return when {
        remainder.isEmpty() -> classPart
        else -> "${classPart.replaceNestedClassDollar()}.$remainder"
    }
}

/**
 * Collapses the calls that [analysisProfile] marks as hidden.
 *
 * A hidden [CallNode] is replaced by a copy of itself whose children are the roots of the visible
 * subtrees found below it (see [findSubTreesToBeShown]). Everything in between is dropped.
 *
 * A hidden call is kept intact when any of its descendants is an [EventNode] carrying a
 * [SwitchEventTracePoint]. Visible nodes and non-call nodes are returned as they are.
 */
internal fun SingleThreadedTable<TraceNode>.collapseLibraries(analysisProfile: AnalysisProfile) = compressNodes { node ->
    if (node !is CallNode || !analysisProfile.shouldBeHidden(node)) return@compressNodes node

    // A hidden call that contains a switch point cannot be collapsed.
    val containsSwitchPoint = node.containsDescendant { descendant ->
        descendant is EventNode && descendant.tracePoint is SwitchEventTracePoint
    }
    if (containsSwitchPoint) return@compressNodes node

    node.copy().also { collapsed ->
        for (visibleRoot in findSubTreesToBeShown(node, analysisProfile)) collapsed.addChild(visibleRoot)
    }
}

/**
 * Collects the roots of all subtrees below [node] that should be shown in the trace. Descendants of
 * those roots are not collected separately.
 *
 * - A non-call node yields nothing, so it is dropped.
 * - A visible call node yields itself.
 * - A hidden call node yields the concatenation of the results for its children, in order.
 */
private fun findSubTreesToBeShown(node: TraceNode, analysisProfile: AnalysisProfile): List<TraceNode> = when {
    node !is CallNode -> emptyList()
    !analysisProfile.shouldBeHidden(node) -> listOf(node)
    else -> node.children.flatMap { child -> findSubTreesToBeShown(child, analysisProfile) }
}

/**
 * Rebuilds every tree of the table by applying [compressionRule] to each node, top-down
 * (see [TraceNode.compress]), and returns the rebuilt table.
 */
private fun SingleThreadedTable<TraceNode>.compressNodes(compressionRule: (TraceNode) -> TraceNode) =
    map { section -> section.map { root -> root.compress(compressionRule) } }

/**
 * Applies [compressionRule] to this node first, then recursively to the children of the node the rule
 * returned.
 *
 * The result is a copy of the rule's output node, with the recursively compressed children attached in
 * their original order.
 */
private fun TraceNode.compress(compressionRule: (TraceNode) -> TraceNode): TraceNode {
    val rewritten = compressionRule(this)
    return rewritten.copy().also { rebuilt ->
        for (child in rewritten.children) rebuilt.addChild(child.compress(compressionRule))
    }
}

/**
 * Tells whether [currentTracePoint] is a thread start whose next call ([nextTracePoint]) is the
 * `invoke` of a `kotlin.jvm.functions.Function0` lambda. Used to remove the lambda invocation line at
 * the beginning of a user-defined thread trace.
 */
private fun isUserThreadStart(currentTracePoint: MethodCallTracePoint, nextTracePoint: MethodCallTracePoint): Boolean {
    if (!currentTracePoint.isThreadStart) return false
    return nextTracePoint.className == "kotlin.jvm.functions.Function0" && nextTracePoint.methodName == "invoke"
}

/**
 * Convenience overload: asks the profile whether the class and method of [callNode]'s call should be
 * hidden.
 */
private fun AnalysisProfile.shouldBeHidden(callNode: CallNode): Boolean {
    val call = callNode.tracePoint
    return shouldBeHidden(call.className, call.methodName)
}
